"""Base helpers for datasets - logic for generating TF Records from CSV data (if not available).

   @author
     Victor I. Afolabi
     Artificial Intelligence Expert & Researcher.
     Email: javafolabi@gmail.com
     GitHub: https://github.com/victor-iyiola

   @project
     File: dataset.py
     Package: diagnosis.datasets
     Created on 10 July, 2019 @ 02:24 PM.

   @license
     BSD-3 Clause license.
     Copyright (c) 2019. Victor I. Afolabi. All rights reserved.
"""

import os
from glob import glob
from tqdm import tqdm

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

__all__ = [
    'create_generator_for_ffn',
    'ffn_serialize_fn',
    'make_tfrecord',
    'convert_single_example',
    'convert_examples_to_features',
    'convert_text_to_feature',
    'create_dataset_for_ffn',
    'create_dataset_for_bert',
    'create_generator_for_bert',
    'bert_serialize_fn',
    'PaddingInputExample',
    'InputExample',
]


# Fixed `random_state` for every train/eval `train_test_split`, so the
# separate "train" and "eval" passes over one CSV partition its rows the same way.
SEED = 42


def _float_list_feature(value):
    """Wrap an iterable of floats in a ``tf.train.Feature`` (float_list)."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)


def _int64_list_feature(value):
    """Wrap an iterable of integers in a ``tf.train.Feature`` (int64_list)."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def _int64_feature(value):
    """Wrap a single integer (or bool) in a one-element int64_list Feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def create_generator_for_ffn(
        file_list,
        mode='train'):
    """Yield stacked (question, answer) embedding pairs from CSV files.

    Each CSV must provide ``question_bert`` and ``answer_bert`` columns whose
    text holds space-separated numbers wrapped in ``[[`` / ``]]``. Rows are
    split 80/20 with ``SEED``: ``mode='train'`` iterates the 80% part, any
    other mode iterates the 20% part.

    Arguments:
        file_list {list[str]} -- paths of the CSV files to read.

    Keyword Arguments:
        mode {str} -- 'train', 'eval' or another split name (default: {'train'}).

    Yields:
        ``(vectors, 1)`` for 'train' / 'eval', where ``vectors`` is
        ``np.stack([question_vector, answer_vector], axis=0)``; for any other
        mode only ``vectors`` is yielded.

    Raises:
        FileNotFoundError -- if a path in ``file_list`` does not exist.
    """
    def _parse_vector(text):
        # Strip the outer double brackets, then parse the remaining numbers.
        return np.fromstring(text.replace('[[', '').replace(']]', ''), sep=' ')

    labelled = mode in ['train', 'eval']
    for csv_path in file_list:
        if not os.path.exists(csv_path):
            raise FileNotFoundError("File %s not found" % csv_path)
        frame = pd.read_csv(csv_path, encoding='utf8')

        train_part, eval_part = train_test_split(
            frame, test_size=0.2, random_state=SEED)
        frame = train_part if mode == 'train' else eval_part

        for _, row in frame.iterrows():
            pair = np.stack([_parse_vector(row.question_bert),
                             _parse_vector(row.answer_bert)], axis=0)
            if labelled:
                yield pair, 1
            else:
                yield pair


def ffn_serialize_fn(features):
    """Serialize one ``(vectors, label)`` pair from the FFN generator.

    Arguments:
        features {tuple} -- ``(vectors, label)``; ``vectors`` is flattened
            into the 'features' float list and ``label`` becomes 'labels'.

    Returns:
        bytes -- the serialized ``tf.train.Example``.
    """
    vectors, label = features[0], features[1]
    feature_map = {
        'features': _float_list_feature(vectors.flatten()),
        'labels': _int64_feature(label),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return example.SerializeToString()


def make_tfrecord(data_dir, generator_fn, serialize_fn, suffix='', **kwargs):
    """Function to make TF Records from csv files.

    Every ``*.csv`` file in ``data_dir`` is converted twice. With
    ``generator_fn(..., mode='train')`` it is written to
    ``<name>_{suffix}_train.tfrecord``. With ``mode='eval'`` it is written to
    ``<name>_{suffix}_eval.tfrecord``. Both files sit next to the source CSV.

    Arguments:
        data_dir {str} -- dir that has csv files and stores the tf records.
        generator_fn {fn} -- A function that takes a list of file paths (plus
            ``mode`` and ``**kwargs``) and yields the parsed records.
        serialize_fn {fn} -- A function that takes one output of
            ``generator_fn`` and returns a serialized tf example.

    Keyword Arguments:
        suffix {str} -- suffix to add to tf record files (default: {''})
        **kwargs -- forwarded unchanged to ``generator_fn``.
    """
    for csv_path in glob(os.path.join(data_dir, '*.csv')):
        # Swap only the extension. str.replace('.csv', ...) would also rewrite
        # a '.csv' appearing earlier in the path (e.g. inside a directory
        # name). It would also leave an upper-case '.CSV' (matched by a
        # case-insensitive glob) untouched, so the writer would overwrite the
        # source file.
        stem = os.path.splitext(csv_path)[0]
        print('Converting file {0} to TF Record'.format(csv_path))
        for mode in ('train', 'eval'):
            record_path = '{0}_{1}_{2}.tfrecord'.format(stem, suffix, mode)
            with tf.io.TFRecordWriter(record_path) as writer:
                for features in generator_fn([csv_path], mode=mode, **kwargs):
                    writer.write(serialize_fn(features))


def create_dataset_for_ffn(
        data_dir,
        mode='train',
        hidden_size=768,
        shuffle_buffer=10000,
        prefetch=10000,
        batch_size=32):
    """Build a batched ``tf.data`` pipeline over the FFN TF Records.

    Looks for ``*_FFN_{mode}.tfrecord`` in ``data_dir``. When none exist, they
    are generated from the CSV files there via ``make_tfrecord``.

    Arguments:
        data_dir {str} -- dir with the csv files / FFN tf records.

    Keyword Arguments:
        mode {str} -- 'train' or 'eval'; 'train' also shuffles
            (default: {'train'}).
        hidden_size {int} -- length of each question / answer vector
            (default: {768}).
        shuffle_buffer {int} -- shuffle buffer size in train mode
            (default: {10000}).
        prefetch {int} -- prefetch buffer size (default: {10000}).
        batch_size {int} -- batch size (default: {32}).

    Returns:
        tf.data.Dataset -- batches of ``(features, labels)``; each element's
        features are reshaped to ``(2, hidden_size)`` before batching.
    """
    pattern = os.path.join(data_dir, '*_FFN_{0}.tfrecord'.format(mode))
    tfrecord_file_list = glob(pattern)
    if not tfrecord_file_list:
        print('TF Record not found')
        make_tfrecord(data_dir, create_generator_for_ffn,
                      ffn_serialize_fn, 'FFN')
        # Look again; otherwise the dataset would be built from the empty
        # list found before the records were written.
        tfrecord_file_list = glob(pattern)

    dataset = tf.data.TFRecordDataset(tfrecord_file_list)

    def _parse_ffn_example(example_proto):
        feature_description = {
            'features': tf.io.FixedLenFeature([2 * hidden_size], tf.float32),
            'labels': tf.io.FixedLenFeature([], tf.int64, default_value=0),
        }
        feature_dict = tf.io.parse_single_example(example_proto,
                                                  feature_description)
        # Stored flat; restore the (question, answer) pair axis.
        features = tf.reshape(feature_dict['features'], (2, hidden_size))
        return features, feature_dict['labels']
    dataset = dataset.map(_parse_ffn_example)

    if mode == 'train':
        dataset = dataset.shuffle(shuffle_buffer)

    dataset = dataset.prefetch(prefetch)
    dataset = dataset.batch(batch_size)

    return dataset


class PaddingInputExample:
    """Placeholder example used to pad the input up to a full batch.

    When running eval/predict on the TPU, the number of examples has to be a
    multiple of the batch size, because the TPU requires a fixed batch size.
    The alternative, dropping the last smaller batch, would mean the entire
    output data is not generated.
    A dedicated class is used instead of ``None`` because treating ``None``
    as padding batches could cause silent errors. ``convert_single_example``
    turns an instance into all-zero features with label 0.
    """


class InputExample:
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Construct an ``InputExample``.

        Args:
          guid: Unique id for the example.
          text_a: string. The untokenized text of the first sequence. For
            single sequence tasks, only this sequence must be specified.
          text_b: (Optional) string. The untokenized text of the second
            sequence. Only must be specified for sequence pair tasks.
            NOTE: ``convert_single_example`` in this module only reads text_a.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid, self.text_a, self.text_b, self.label = (
            guid, text_a, text_b, label)


def convert_single_example(tokenizer, example, max_seq_length=256, dynamic_padding=False):
    """Turn one example into ``(input_ids, input_mask, segment_ids, label)``.

    ``example.text_a`` is tokenized, truncated to ``max_seq_length - 2``
    tokens and wrapped as ``[CLS] ... [SEP]``. Only a single segment is
    produced (all segment ids are 0); ``text_b`` is not used.

    Arguments:
        tokenizer -- object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        example {InputExample | PaddingInputExample} -- example to convert.

    Keyword Arguments:
        max_seq_length {int} -- maximum sequence length (default: {256}).
        dynamic_padding {bool} -- if False, zero-pad all three lists up to
            ``max_seq_length``; if True, leave them at their natural length
            (default: {False}).

    Returns:
        tuple -- ``(input_ids, input_mask, segment_ids, label)``. A
        ``PaddingInputExample`` always yields all-zero lists of length
        ``max_seq_length`` and label 0, regardless of ``dynamic_padding``.
    """
    if isinstance(example, PaddingInputExample):
        return ([0] * max_seq_length, [0] * max_seq_length,
                [0] * max_seq_length, 0)

    # Reserve two positions for the [CLS] and [SEP] markers.
    body = tokenizer.tokenize(example.text_a)
    budget = max_seq_length - 2
    if len(body) > budget:
        body = body[0:budget]

    tokens = ["[CLS]"] + list(body) + ["[SEP]"]
    segment_ids = [0 for _ in tokens]

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # 1 marks a real token that is attended to; padding positions get 0.
    input_mask = [1] * len(input_ids)

    if not dynamic_padding:
        missing = max_seq_length - len(input_ids)
        if missing > 0:
            input_ids.extend([0] * missing)
            input_mask.extend([0] * missing)
            segment_ids.extend([0] * missing)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

    return input_ids, input_mask, segment_ids, example.label


def convert_examples_to_features(tokenizer, examples, max_seq_length=256, dynamic_padding=False):
    """Convert a sequence of examples into numpy feature arrays.

    Each example goes through ``convert_single_example`` and the per-example
    results are gathered column-wise.

    Returns:
        tuple -- ``(input_ids, input_masks, segment_ids, labels)``. The first
        three are ``np.squeeze``-d, so every length-1 axis is removed (e.g.
        the example axis when only one example is given). ``labels`` is
        reshaped into a single column via ``reshape(-1, 1)``.
    """
    rows = [
        convert_single_example(tokenizer, example, max_seq_length,
                               dynamic_padding=dynamic_padding)
        for example in examples
    ]
    ids_rows = [ids for ids, _, _, _ in rows]
    mask_rows = [mask for _, mask, _, _ in rows]
    segment_rows = [segments for _, _, segments, _ in rows]
    label_values = [label for _, _, _, label in rows]
    return (
        np.squeeze(np.array(ids_rows)),
        np.squeeze(np.array(mask_rows)),
        np.squeeze(np.array(segment_rows)),
        np.array(label_values).reshape(-1, 1),
    )


def convert_text_to_feature(text, tokenizer, max_seq_length, dynamic_padding=False):
    """Featurize one raw text (no guid, no label).

    Wraps ``text`` in an ``InputExample`` and runs it through
    ``convert_examples_to_features`` as a one-element batch.

    Returns:
        tuple -- ``(input_ids, input_masks, segment_ids, labels)`` exactly as
        returned by ``convert_examples_to_features``.
    """
    single = [InputExample(guid=None, text_a=text)]
    return convert_examples_to_features(
        tokenizer, single, max_seq_length, dynamic_padding=dynamic_padding)


def create_generator_for_bert(
        file_list,
        tokenizer,
        mode='train',
        max_seq_length=256,
        dynamic_padding=False):
    """Yield BERT features for every (question, answer) row of the CSV files.

    Rows are split 80/20 with ``SEED``: ``mode='train'`` iterates the 80%
    part, any other mode iterates the 20% part. A row is skipped when
    featurizing its question or answer raises ``ValueError`` or
    ``AttributeError``.

    Arguments:
        file_list {list[str]} -- paths of the CSV files to read.
        tokenizer -- passed through to ``convert_text_to_feature``.

    Keyword Arguments:
        mode {str} -- 'train' or another split name (default: {'train'}).
        max_seq_length {int} -- maximum sequence length (default: {256}).
        dynamic_padding {bool} -- forwarded to ``convert_text_to_feature``
            (default: {False}).

    Yields:
        ``(features, 1)`` where ``features`` is the 6-tuple
        ``(q_input_ids, q_input_masks, q_segment_ids,
        a_input_ids, a_input_masks, a_segment_ids)``.

    Raises:
        FileNotFoundError -- if a path in ``file_list`` does not exist.
    """
    def _featurize(text):
        # Keep ids, masks and segment ids; drop the fourth element (labels).
        return convert_text_to_feature(
            text, tokenizer, max_seq_length,
            dynamic_padding=dynamic_padding)[:3]

    for csv_path in file_list:
        if not os.path.exists(csv_path):
            raise FileNotFoundError("File %s not found" % csv_path)

        frame = pd.read_csv(csv_path, lineterminator='\n')
        if os.path.basename(csv_path) == 'healthtap_data_cleaned.csv':
            # This file's columns are renamed; its first column is dropped.
            frame.columns = ['index', 'question', 'answer']
            frame.drop(columns=['index'], inplace=True)

        train_part, eval_part = train_test_split(
            frame, test_size=0.2, random_state=SEED)
        frame = train_part if mode == 'train' else eval_part

        rows = tqdm(frame.iterrows(), total=frame.shape[0],
                    desc='Writing to TFRecord')
        for _, row in rows:
            try:
                question = _featurize(row.question)
                answer = _featurize(row.answer)
            except (ValueError, AttributeError):
                continue
            yield (question + answer, 1)


def _qa_ele_to_length(features, labels):
    """Bucketing key: question length plus answer length (``labels`` unused).

    Used as the default ``element_length_func`` of ``create_dataset_for_bert``.
    """
    question_len = tf.shape(features['q_input_ids'])[0]
    answer_len = tf.shape(features['a_input_ids'])[0]
    return question_len + answer_len


def bert_serialize_fn(features):
    """Serialize one ``(features, label)`` pair from the BERT generator.

    ``features`` is indexed as ``(q_input_ids, q_input_masks, q_segment_ids,
    a_input_ids, a_input_masks, a_segment_ids)``. Each entry is flattened
    into an int64 list. The shapes of the two id arrays are stored as
    ``q_input_shape`` / ``a_input_shape``.

    Returns:
        bytes -- the serialized ``tf.train.Example``.
    """
    feature, labels = features
    q_ids, q_masks, q_segments, a_ids, a_masks, a_segments = (
        feature[i] for i in range(6))

    as_ints = _int64_list_feature
    feature_map = {
        'q_input_ids': as_ints(q_ids.flatten()),
        'q_input_masks': as_ints(q_masks.flatten()),
        'q_segment_ids': as_ints(q_segments.flatten()),
        'q_input_shape': as_ints(q_ids.shape),
        'a_input_ids': as_ints(a_ids.flatten()),
        'a_input_masks': as_ints(a_masks.flatten()),
        'a_segment_ids': as_ints(a_segments.flatten()),
        'a_input_shape': as_ints(a_ids.shape),
        'labels': _int64_feature(labels),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return example.SerializeToString()


def create_dataset_for_bert(
        data_dir,
        tokenizer=None,
        mode='train',
        max_seq_length=256,
        shuffle_buffer=10000,
        prefetch=10000,
        batch_size=32,
        dynamic_padding=False,
        bucket_batch_sizes=None,
        bucket_boundaries=None,
        element_length_func=_qa_ele_to_length):
    """Build a batched ``tf.data`` pipeline over the BertFFN TF Records.

    Looks for ``*_BertFFN_{mode}.tfrecord`` in ``data_dir``. When none exist,
    they are generated from the CSV files there via ``make_tfrecord``. The
    generated records are always written without padding
    (``dynamic_padding=True`` is passed to the generator), so padding is
    applied here at batch time.

    Arguments:
        data_dir {str} -- dir with the csv files / BertFFN tf records.

    Keyword Arguments:
        tokenizer -- tokenizer used when the records must be generated.
        mode {str} -- 'train' or 'eval'; 'train' also shuffles.
        max_seq_length {int} -- max tokens per question / answer when
            generating records; also the fixed padded length when
            ``dynamic_padding`` is False.
        shuffle_buffer {int} -- shuffle buffer size in train mode.
        prefetch {int} -- number of batches to prefetch.
        batch_size {int} -- batch size when ``dynamic_padding`` is False.
        dynamic_padding {bool} -- if True, bucket by ``element_length_func``
            via ``bucket_by_sequence_length``; otherwise pad every sequence to
            ``max_seq_length`` and batch with ``batch_size``.
        bucket_batch_sizes {list[int]} -- per-bucket batch sizes
            (default: ``[32, 16, 8]``).
        bucket_boundaries {list[int]} -- bucket length boundaries
            (default: ``[64, 128]``).
        element_length_func {fn} -- maps ``(features, label)`` to a length.

    Returns:
        tf.data.Dataset -- batches of ``(features_dict, labels)``. The dict
        holds the six id / mask / segment tensors plus ``labels``.
    """
    # None sentinels instead of shared mutable list defaults.
    if bucket_batch_sizes is None:
        bucket_batch_sizes = [32, 16, 8]
    if bucket_boundaries is None:
        bucket_boundaries = [64, 128]

    pattern = os.path.join(data_dir, '*_BertFFN_{0}.tfrecord'.format(mode))
    tfrecord_file_list = glob(pattern)
    if not tfrecord_file_list:
        print('TF Record not found')
        make_tfrecord(data_dir, create_generator_for_bert,
                      bert_serialize_fn, 'BertFFN',
                      tokenizer=tokenizer,
                      dynamic_padding=True,
                      max_seq_length=max_seq_length)
        tfrecord_file_list = glob(pattern)

    dataset = tf.data.TFRecordDataset(tfrecord_file_list)

    sequence_keys = ('q_input_ids', 'q_input_masks', 'q_segment_ids',
                     'a_input_ids', 'a_input_masks', 'a_segment_ids')

    def _parse_bert_example(example_proto):
        feature_description = {
            key: tf.io.VarLenFeature(tf.int64) for key in sequence_keys}
        feature_description['labels'] = tf.io.FixedLenFeature(
            [], tf.int64, default_value=0)
        feature_dict = tf.io.parse_single_example(
            example_proto, feature_description)
        dense_feature_dict = {
            key: tf.sparse.to_dense(feature_dict[key]) for key in sequence_keys}
        dense_feature_dict['labels'] = feature_dict['labels']
        return dense_feature_dict, feature_dict['labels']
    dataset = dataset.map(_parse_bert_example)

    if mode == 'train':
        dataset = dataset.shuffle(shuffle_buffer)
    if dynamic_padding:
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                element_length_func=element_length_func,
                bucket_batch_sizes=bucket_batch_sizes,
                bucket_boundaries=bucket_boundaries
            ))
    else:
        # Stored sequences have varying lengths, so a plain `batch` fails on
        # mismatched shapes. Zero-pad every sequence to the fixed
        # `max_seq_length` instead. This assumes the records were generated
        # with a max_seq_length no larger than this one.
        feature_shapes = {key: [max_seq_length] for key in sequence_keys}
        feature_shapes['labels'] = []
        dataset = dataset.padded_batch(
            batch_size, padded_shapes=(feature_shapes, []))

    dataset = dataset.prefetch(prefetch)

    return dataset
